
Initial version porting the eden configs over to the new evo2 recipe #1502

Open
jstjohn wants to merge 12 commits into main from jstjohn/evo2_llama_configs_and_savanna_convert

Conversation

@jstjohn
Collaborator

@jstjohn jstjohn commented Mar 9, 2026

Description

This PR adds Eden (Llama 3.1) model support, Savanna/Vortex checkpoint converters, and a standardized model naming convention to the Megatron Bridge–based Evo2 recipe (bionemo-recipes/recipes/evo2_megatron/).

Eden (Llama 3.1) model support

  • New eden_provider.py defining EdenModelProvider and size-specific subclasses (eden_7b through eden_35b) that inherit from Llama31ModelProvider.
  • train.py now dispatches to gpt_forward_step for Eden models and automatically disables fp32_residual_connection (incompatible with standard TE LayerNormLinear layers — Hyena handles this via manual dtype casting, but GPT/Llama does not).
  • infer.py now initializes ProcessGroupCollection for non-Hyena providers (required by GPTModelProvider.provide()) and uses StaticInferenceContext instead of HyenaInferenceContext for Eden models. The flash_decode attribute is guarded to Hyena-only.
  • predict.py already worked architecture-agnostically via dynamic model loading; no changes required.
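The forward-step dispatch and the fp32-residual guard described above can be sketched as follows. This is a simplified stand-in for the real train.py logic: the option dicts, the dict-based config, and the string return values are all placeholders for illustration.

```python
# Hypothetical stand-ins for the real option maps in evo2_provider.py / eden_provider.py.
HYENA_MODEL_OPTIONS = {"evo2_1b_base": "hyena-provider", "evo2_7b": "hyena-provider"}
EDEN_MODEL_OPTIONS = {"eden_7b": "eden-provider", "eden_35b": "eden-provider"}


def infer_model_type(model_size: str) -> str:
    """Classify a model key by which option map defines it."""
    if model_size in HYENA_MODEL_OPTIONS:
        return "hyena"
    if model_size in EDEN_MODEL_OPTIONS:
        return "eden"
    raise ValueError(f"Unknown model size: {model_size!r}")


def select_forward_step(model_size: str, config: dict) -> str:
    """Pick the forward step and sanitize the config, as train.py does for Eden."""
    if infer_model_type(model_size) == "hyena":
        return "hyena_forward_step"
    # Standard TE LayerNormLinear layers cannot consume fp32 residuals,
    # so the flag is force-disabled on the GPT/Llama (Eden) path.
    if config.get("fp32_residual_connection"):
        config["fp32_residual_connection"] = False
    return "gpt_forward_step"
```

The point of routing through a single classifier is that train.py, infer.py, and predict.py all branch on the same answer rather than each re-checking key prefixes.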

Checkpoint converters

  • savanna_to_mbridge.py — converts ARC Savanna .pt checkpoints (local or downloaded from Hugging Face via hf_hub_download) into MBridge distributed checkpoint format.
  • mbridge_to_vortex.py — exports MBridge checkpoints to ARC's single-file Vortex inference format, handling MLP weight splitting, Hyena filter pole/residue computation, and TE layernorm key remapping.
  • Both are registered as console scripts (evo2_convert_savanna_to_mbridge, evo2_export_mbridge_to_vortex).
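At their core, both converters are pattern-based key remaps over a checkpoint state dict. A toy sketch of that idea follows; the patterns below are invented for illustration and are not the real Savanna or MBridge key names.

```python
import re

# Illustrative only: NOT the actual Savanna/MBridge key schemas.
KEY_PATTERNS = [
    (re.compile(r"^sequential\.(\d+)\.attention\.(.*)$"),
     r"decoder.layers.\1.self_attention.\2"),
    (re.compile(r"^word_embeddings\.(.*)$"),
     r"embedding.word_embeddings.\1"),
]


def remap_state_dict(source_sd: dict) -> dict:
    """Apply the first matching pattern to each key; keep unmatched keys as-is."""
    target_sd = {}
    for key, tensor in source_sd.items():
        for pattern, replacement in KEY_PATTERNS:
            if pattern.match(key):
                key = pattern.sub(replacement, key)
                break
        target_sd[key] = tensor
    return target_sd
```

The real exporters additionally split fused MLP weights, compute Hyena filter poles/residues, and remap TE layernorm keys, but the remap loop is the backbone.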

Model naming convention

The previous model size keys (1b, 7b, 40b, 7b_arc_longcontext, …) were ambiguous — 7b referred to Striped Hyena while 7B referred to Llama. This PR replaces them with explicit, architecture-prefixed keys:

  • evo2_* for models matching public ARC checkpoints (e.g. evo2_1b_base, evo2_7b, evo2_40b_base). _base = 8K context, without it = 1M context.
  • striped_hyena_*_nv for NVIDIA-modified Hyena variants.
  • eden_* for Llama 3.1 variants.
  • Added evo2_20b config based on arcinstitute/savanna_evo2_20b.

Documentation updates

  • README.md — added model naming convention tables, Vortex export section with round-trip example, updated all CLI examples to new model keys.
  • checkpoint/README.md — updated --model-size documentation.
  • Both Jupyter notebooks (zeroshot_brca1.ipynb, fine-tuning-tutorial.ipynb) — updated MODEL_SIZE and --model-size references.

Usage

Training an Eden model:

```bash
torchrun --nproc-per-node 1 --no-python train_evo2 \
  --model-size eden_7b --num-layers 2 --max-steps 5 \
  --mock-data --seq-length 64 --mixed-precision-recipe bf16_mixed \
  --no-activation-checkpointing
```

Converting Savanna checkpoint to MBridge:

```bash
evo2_convert_savanna_to_mbridge \
  --savanna-ckpt-path arcinstitute/savanna_evo2_1b_base \
  --mbridge-ckpt-dir /tmp/mbridge_1b \
  --model-size evo2_1b_base \
  --tokenizer-path tokenizers/nucleotide_fast_tokenizer_256
```

Exporting MBridge to Vortex:

```bash
evo2_export_mbridge_to_vortex \
  --mbridge-ckpt-dir /tmp/mbridge_1b/iter_0000001 \
  --output-path /tmp/evo2_1b_vortex.pt \
  --model-size evo2_1b_base
```

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Triggering Code Rabbit AI Review

To trigger a code review from CodeRabbit, comment on the pull request with one of its review commands.

See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

  • New Features

    • Added Eden (Llama 3.1) model family support alongside existing Hyena models (11B–35B variants).
    • Added checkpoint conversion utilities: Savanna-to-MBridge and MBridge-to-Vortex exporters with CLI tools.
  • Documentation

    • Updated model naming convention with Evo2 prefixes (e.g., evo2_1b_base, evo2_7b).
    • Expanded documentation for checkpoint conversion workflows and available models.
  • Tests

    • Added comprehensive test coverage for model providers, checkpoint conversions, and Eden inference/prediction workflows.

Signed-off-by: John St. John <jstjohn@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 9, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7f018d6e-1c42-462a-9447-b2db182ebad2

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough


This PR introduces support for Eden (Llama 3.1) model variants alongside the existing Hyena SSM models in the Evo2 framework. New checkpoint conversion utilities enable Savanna-to-MBridge and MBridge-to-Vortex transformations, updated runtime scripts support model-type branching, and CLI entry points are added for checkpoint operations. Comprehensive test coverage validates model providers, roundtrip conversions, and inference/training workflows for both architectures.

Changes

  • Documentation & CLI Configuration (README.md, pyproject.toml, checkpoint/README.md): Updated model naming conventions, command examples, and model-key references; added two new CLI scripts (evo2_convert_savanna_to_mbridge, evo2_export_mbridge_to_vortex) as entry points.
  • Model Provider Infrastructure (src/bionemo/evo2/models/eden_provider.py): New module introducing seven Eden model providers (Eden11B-Eden35B) as Llama 3.1 variants with architecture-specific hyperparameters, tokenizer patching, and a public EDEN_MODEL_OPTIONS mapping.
  • Evo2 Provider Integration (src/bionemo/evo2/models/evo2_provider.py): Extended HYENA_MODEL_OPTIONS with Evo2 and striped variants, added Hyena20bModelProvider, introduced MODEL_OPTIONS merging Hyena and Eden, and added an infer_model_type() utility for model classification.
  • Checkpoint Conversion Utilities (src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py, savanna_to_mbridge.py, nemo2_to_mbridge.py): New exporters: mbridge_to_vortex converts Megatron Bridge DCP to Vortex format with layer-by-layer state mapping; savanna_to_mbridge converts Savanna PyTorch checkpoints to MBridge with multi-part download, state mapping, and metadata packaging; nemo2_to_mbridge updated to use the unified MODEL_OPTIONS.
  • Training & Inference Runtime (src/bionemo/evo2/run/train.py, infer.py, predict.py): Train: added model-type branching to select gpt_forward_step (Eden) or hyena_forward_step (Hyena), plus conditional fp32 residual logic. Infer: introduced a Hyena-specific inference context and flash-decode gating via an is_hyena type check. Predict: reorganized imports to module scope.
  • Test Infrastructure (tests/bionemo/evo2/conftest.py): Added a session-scoped mbridge_eden_checkpoint fixture for tiny Eden model training, updated the model-size identifier in the Nemo2-to-MBridge conversion, and improved environment cleanup between tests.
  • Inference & Prediction Tests (tests/bionemo/evo2/run/test_infer.py, test_predict.py): Added Eden-specific inference and prediction tests with determinism validation, top-k/temperature sampling, and phylogenetic prompt handling; replaced hardcoded model sizes with Evo2 identifiers.
  • Training Tests (tests/bionemo/evo2/run/test_train.py): Extended with Eden fine-tuning end-to-end tests, updated model-size parameters, and added TensorBoard log verification; uses bionemo_load for checkpoint retrieval.
  • Roundtrip Validation Tests (tests/bionemo/evo2/test_checkpoint_roundtrip.py, test_eden_llama_roundtrip.py, _eden_roundtrip_helper.py): New test modules validating Savanna→MBridge→Vortex state-dict key equality, tensor shapes, and values; Eden HF export/import roundtrip with weight and prediction equality checks; helper utilities for distributed MBridge↔HF conversion.
  • Provider & Model Tests (tests/bionemo/evo2/test_evo2.py, test_model_providers.py): Updated model-size identifiers across test matrices; new comprehensive provider validation tests checking model-key presence, architecture parameters, tokenizer patching, state-dict mapping for Savanna/Vortex formats, and infer_model_type() behavior.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Savanna as Savanna<br/>Checkpoint
    participant Conv1 as savanna_to_mbridge<br/>Converter
    participant MBridge as Megatron Bridge<br/>DCP Checkpoint
    participant Conv2 as mbridge_to_vortex<br/>Exporter
    participant Vortex as Vortex<br/>Format

    User->>Conv1: savanna_to_mbridge()<br/>(savanna_path, model_size)
    Conv1->>Savanna: load_savanna_state_dict()
    Savanna-->>Conv1: state_dict
    Conv1->>Conv1: select model_provider<br/>from MODEL_OPTIONS
    Conv1->>Conv1: savanna_to_mbridge_state_dict()<br/>(apply pattern mapping)
    Conv1->>Conv1: package_mbridge_checkpoint()<br/>(write DCP structure)
    Conv1-->>MBridge: checkpoint written
    MBridge-->>User: output_path

    User->>Conv2: mbridge_to_vortex()<br/>(mbridge_dir, model_size)
    Conv2->>MBridge: load_mbridge_state_dict()
    MBridge-->>Conv2: state_dict
    Conv2->>Conv2: select HyenaModelProvider<br/>from HYENA_MODEL_OPTIONS
    Conv2->>Conv2: mbridge_to_vortex_state_dict()<br/>(per-layer conversion:<br/>embedding, decoder blocks,<br/>final norm)
    Conv2->>Conv2: _convert_hyena_layer()<br/>or _convert_attention_layer()
    Conv2->>Conv2: _convert_mlp()
    Conv2->>Conv2: _build_vortex_config()
    Conv2-->>Vortex: .pt + config.json
    Vortex-->>User: export complete
sequenceDiagram
    actor User
    participant Train as train.py<br/>Pretraining
    participant Infer as infer_model_type()<br/>Classifier
    participant Provider as Model Provider<br/>(Hyena or Eden)
    participant FwdStep as Forward Step<br/>Function

    User->>Train: launch with --model-size
    Train->>Infer: infer_model_type(model_size)
    alt model_size in HYENA_MODEL_OPTIONS
        Infer-->>Train: "hyena"
        Train->>Provider: HyenaModelProvider
        Train->>FwdStep: select hyena_forward_step
    else model_size in EDEN_MODEL_OPTIONS
        Infer-->>Train: "eden"
        Train->>Provider: EdenModelProvider
        Train->>FwdStep: select gpt_forward_step
    end
    alt model_type != "hyena" && fp32_residual_connection
        Train->>Train: disable fp32_residual_connection
    end
    Train->>FwdStep: launch pretraining<br/>with forward_step_fn
    FwdStep-->>Train: loss, gradients
    Train-->>User: training complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 With Eden's grace and Hyena's stride,
New checkpoints convert far and wide!
From Savanna to Vortex we smoothly sail,
Model-branching paths, none shall fail! ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check (✅ Passed): The title accurately summarizes the main change: porting Eden (Llama 3.1) configs to the Evo2 recipe, which is the primary objective of this PR.
  • Description check (✅ Passed): The PR description is comprehensive and covers all required template sections: detailed description of changes, usage examples, type of changes (marked), and pre-submit checklist completed.
  • Docstring Coverage (✅ Passed): Docstring coverage is 97.83%, which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

jstjohn added 3 commits March 9, 2026 20:45
Signed-off-by: John St. John <jstjohn@nvidia.com>
…tions

Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
@jstjohn
Collaborator Author

jstjohn commented Mar 10, 2026

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🧹 Nitpick comments (4)
bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py (1)

714-805: Use the source infer.py launch path for the Eden subprocesses.

This file already introduced _infer_script_path() and the PYTHONPATH prepend so local infer.py fixes are exercised without reinstalling. These new Eden cases still launch -m bionemo.evo2.run.infer, so a non-editable test environment can end up validating the installed package instead of this PR.

Suggested direction
-        "-m",
-        "bionemo.evo2.run.infer",
+        str(_infer_script_path()),
         "--ckpt-dir",
         str(mbridge_eden_checkpoint_path),
         ...
     ]

     env = copy.deepcopy(PRETEST_ENV)
+    src_dir = str(_recipe_root() / "src")
+    env["PYTHONPATH"] = src_dir + os.pathsep + env.get("PYTHONPATH", "")

Apply the same launch pattern in test_infer_eden_deterministic().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py`
around lines 714 - 805, The Eden tests are launching the installed module (-m
bionemo.evo2.run.infer) instead of the local infer.py; update
test_infer_eden_deterministic to use the same launch pattern as
test_infer_eden_runs: replace the "-m bionemo.evo2.run.infer" style invocation
with the script path from _infer_script_path() (use that function to build the
cmd entry) and ensure the env prepends PYTHONPATH the same way (reuse
PRETEST_ENV modification logic used in test_infer_eden_runs) so the local source
infer.py is executed; reference test_infer_eden_deterministic,
test_infer_eden_runs, _infer_script_path, and PRETEST_ENV to locate where to
apply the changes.
bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py (1)

63-63: Consider adding weights_only=True or explicit weights_only=False to torch.load.

torch.load without weights_only will default to False and emit a deprecation warning in recent PyTorch versions. Since these are locally-generated prediction files (tensors only), weights_only=True should work and is safer.

Suggested fix
-    preds = [torch.load(pf) for pf in pred_files]
+    preds = [torch.load(pf, weights_only=True) for pf in pred_files]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`
at line 63, The list comprehension using torch.load(pf) to build preds should
pass an explicit weights_only argument to avoid the deprecation warning and
ensure correct behavior for tensor-only prediction files; update the
comprehension that references pred_files and preds to call torch.load(pf,
weights_only=True) (or weights_only=False if non-weight objects are expected) so
loading is explicit and future-proof.
bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py (2)

132-132: weights_only=False is necessary but has security implications.

This is required for loading Savanna checkpoints that may contain non-tensor data, but it can execute arbitrary code via pickle when loading untrusted files. Consider adding a comment noting this:

Suggested documentation
-    raw = torch.load(str(path), map_location="cpu", weights_only=False)
+    # Note: weights_only=False is required for Savanna checkpoints containing custom objects.
+    # Only load checkpoints from trusted sources.
+    raw = torch.load(str(path), map_location="cpu", weights_only=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py`
at line 132, Add a brief inline comment next to the torch.load call (the line
setting raw = torch.load(str(path), map_location="cpu", weights_only=False))
explaining that weights_only=False is required to load Savanna checkpoints
containing non-tensor data but is unsafe for untrusted files because it uses
pickle and can execute arbitrary code; state that callers must ensure the path
is trusted (or sanitize/validate inputs) before loading.

95-96: Bare except Exception is too broad.

This catches all exceptions including unexpected ones like KeyboardInterrupt (actually BaseException, but Exception still catches many things). Consider catching specific HuggingFace Hub exceptions:

Suggested fix
-    except Exception:
+    except (huggingface_hub.errors.EntryNotFoundError, huggingface_hub.errors.RepositoryNotFoundError):
         logger.warning(f"Single-file download failed for {repo_id}, trying multi-part shards...")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py`
around lines 95 - 96, The bare except in the single-file download block (the
except Exception: that logs "Single-file download failed for {repo_id}, trying
multi-part shards...") is too broad; replace it by catching specific HF and
network-related exceptions (for example
huggingface_hub.exceptions.RepositoryNotFoundError,
huggingface_hub.utils.entry_not_found_error or RevisionNotFoundError, and
requests.exceptions.HTTPError/ConnectionError) and bind the exception (except
(RepositoryNotFoundError, RevisionNotFoundError, HTTPError, ConnectionError) as
e:) so you can log the actual error (include exc_info or str(e)) while allowing
other unexpected exceptions to propagate (re-raise or don’t catch them). Locate
the except block around the single-file download attempt and update the except
clause and logger call accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/recipes/evo2_megatron/README.md`:
- Around line 157-162: Add a language identifier (bash) to the fenced code
blocks containing the shell commands (e.g., the blocks that show
evo2_export_mbridge_to_vortex and evo2_convert_savanna_to_mbridge) so
markdownlint stops warning; locate the backtick fences around those command
examples and change the opening fence from ``` to ```bash for each occurrence
(including the second block around the evo2_convert_savanna_to_mbridge example).

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/eden_provider.py`:
- Around line 130-139: The patch_eden_tokenizer function is defined but never
used at runtime; either remove this dead function and its export (and update the
unit test to use the runtime implementation) or integrate it into the recipes
tokenizer flow by importing and calling patch_eden_tokenizer immediately after
the tokenizer is constructed in predict.py (so the tokenizer uses BOS=1, EOS=2,
SEP=3, PAD=0); also ensure any exported symbol lists are updated to remove the
orphaned function if you delete it and avoid duplicating functionality already
present in the other package-level patch implementation.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`:
- Around line 1115-1132: The bug: MODEL_OPTIONS is built as
{**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS} which lets EDEN override HYENA on
key collision, but infer_model_type checks HYENA first causing inconsistent
behavior; fix by adding a runtime collision check after constructing
MODEL_OPTIONS that computes collisions = set(HYENA_MODEL_OPTIONS) &
set(EDEN_MODEL_OPTIONS) and either raise a clear ValueError (or log and resolve
to a chosen precedence) if collisions is non-empty, and update infer_model_type
to rely on MODEL_OPTIONS (or document the chosen precedence) so behavior is
consistent; also update infer_model_type's docstring to include an Args section
describing the model_size parameter.
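The suggested guard might be sketched like this; the dict contents are placeholders, and only the collision check plus the merged-map lookup pattern are the point.

```python
# Placeholder contents; the real maps live in evo2_provider.py / eden_provider.py.
HYENA_MODEL_OPTIONS = {"evo2_1b_base": "hyena-cfg", "evo2_7b": "hyena-cfg"}
EDEN_MODEL_OPTIONS = {"eden_7b": "eden-cfg"}

# Fail fast if a key is defined by both families; otherwise the dict merge
# silently lets Eden shadow Hyena while a Hyena-first classifier disagrees.
_collisions = set(HYENA_MODEL_OPTIONS) & set(EDEN_MODEL_OPTIONS)
if _collisions:
    raise ValueError(f"Model keys defined in both option maps: {sorted(_collisions)}")

MODEL_OPTIONS = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}


def infer_model_type(model_size: str) -> str:
    """Classify a model key consistently with the merged MODEL_OPTIONS map.

    Args:
        model_size: A key such as "evo2_7b" or "eden_7b".
    """
    if model_size not in MODEL_OPTIONS:
        raise ValueError(f"Unknown model size: {model_size!r}")
    return "eden" if model_size in EDEN_MODEL_OPTIONS else "hyena"
```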
- Around line 656-699: Hyena20bModelProvider defines an incorrect/unused
attribute short_conv_len and an orphan hyena_out_proj_bias; remove
short_conv_len (it duplicates/typoed counterpart hyena_short_conv_len) and
delete hyena_out_proj_bias unless you add a corresponding field in HyenaConfig
and wire it into the model code; update Hyena20bModelProvider by removing the
short_conv_len and hyena_out_proj_bias declarations (or rename short_conv_len to
hyena_short_conv_len only if HyenaConfig lacks that field and you also add it
there).

In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py`:
- Around line 334-336: You're directly mutating the private attribute
model_provider._pg_collection with
ProcessGroupCollection.use_mpu_process_groups(), which is fragile; instead
expose and use a public API or setter on the provider (e.g., add or call a
method like set_process_group_collection or a constructor/init parameter on the
ModelProvider class) so non-Hyena models can be configured without touching
internals—update the provider implementation to accept and store the
ProcessGroupCollection via that public method and replace the direct assignment
at the call site with the new setter or init call, keeping Hyena models'
internal behavior unchanged.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`:
- Around line 133-167: The export currently silently omits required tensors
(e.g., embedding.word_embeddings.weight, decoder.final_norm.weight and per-layer
weights produced by _convert_hyena_layer, _convert_attention_layer, and
_convert_mlp) so add a validation pass after the loop that defines the mandatory
target keys (embedding_layer.weight, unembed.weight, norm.scale and all expected
per-layer keys derived from prefix/block_prefix and the pattern) and check their
presence in the resulting vortex_sd (or mbridge_state_dict if conversions expect
source keys); collect missing keys into a list and raise a descriptive exception
listing layer and key names (including references to the layer index and symbol)
before writing the .pt/config.json to fail fast on bad --model-size or --no-te
choices.
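The validation pass could be as small as collecting missing keys and raising before anything is written. A sketch follows; the key names are illustrative, not the actual Vortex schema.

```python
def validate_vortex_keys(vortex_sd: dict, num_layers: int) -> None:
    """Raise a descriptive error if any mandatory target key is absent."""
    # Illustrative key names; the real exporter would derive these from the
    # layer pattern (Hyena vs. attention blocks) and the --no-te choice.
    expected = ["embedding_layer.weight", "unembed.weight", "norm.scale"]
    for layer in range(num_layers):
        expected.append(f"blocks.{layer}.pre_norm.scale")
    missing = [key for key in expected if key not in vortex_sd]
    if missing:
        raise KeyError(
            f"Vortex export is missing {len(missing)} required tensor(s): {missing}"
        )
```

Running this before serializing the .pt/config.json turns a silently truncated export into an immediate, named failure.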
- Around line 48-56: The current logic assumes mbridge_ckpt_dir is the
checkpoint root and fails if the user passes an iter_* directory; modify the
resolver around latest_file/iter_dir so that if mbridge_ckpt_dir.name matches
the iter_* pattern (e.g., startswith "iter_" or matches r"^iter_\d+$") you treat
mbridge_ckpt_dir itself as the iter_dir; otherwise keep the existing flow (check
for latest_checkpointed_iteration.txt, parse iteration into iter_{:07d}, or
fallback to glob("iter_*")). Update uses of iteration/iter_dirs to reflect this
early-path selection so valid direct iter_* paths are accepted.
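The early-path check described above fits naturally ahead of the existing resolution flow; here is a pathlib sketch using the filenames named in the comment (the surrounding function in mbridge_to_vortex.py may differ).

```python
import re
from pathlib import Path


def resolve_iter_dir(mbridge_ckpt_dir: Path) -> Path:
    """Accept either a checkpoint root or an iter_* directory passed directly."""
    if re.fullmatch(r"iter_\d+", mbridge_ckpt_dir.name):
        return mbridge_ckpt_dir  # user already pointed at the iteration dir
    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
    if latest_file.exists():
        iteration = int(latest_file.read_text().strip())
        return mbridge_ckpt_dir / f"iter_{iteration:07d}"
    iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
    if not iter_dirs:
        raise FileNotFoundError(f"No iter_* directories under {mbridge_ckpt_dir}")
    return iter_dirs[-1]
```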

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py`:
- Around line 39-79: The fixtures download mutable HuggingFace checkpoints and
allow unsafe pickle loading; update savanna_checkpoint_path and
vortex_reference_path to pass explicit immutable commit SHAs via the revision=
parameter in their hf_hub_download(...) calls (use the specific commit SHA for
SAVANNA_1B_REPO and VORTEX_1B_REPO to pin the golden data), and modify
vortex_reference_sd to load the reference safely by calling torch.load(...,
map_location="cpu", weights_only=True) (or prefer safetensors if a .safetensors
artifact exists) so remote .pt files cannot execute pickle code.

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`:
- Around line 148-186: The test currently computes original_preds via
_run_predict(eden_ckpt, ...) but then only compares original_hf and
reimported_hf logits; change the test to run _run_predict on the roundtripped HF
checkpoint (use hf_reimported_dir) to produce hf_preds and then compare
original_preds["log_probs_seqs"] to hf_preds["log_probs_seqs"] (or the
equivalent per-token log-prob key) using a numeric assert (e.g.,
torch.testing.assert_close or numpy.testing.assert_allclose) so the comparison
actually verifies the roundtrip predictions; locate and update the block that
currently creates original_hf/reimported_hf and the final
torch.testing.assert_close to instead call _run_predict for hf_reimported_dir
(or both HF and eden if you want) and perform the assertion on the
"log_probs_seqs" entries.

---

Nitpick comments:
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py`:
- Line 132: Add a brief inline comment next to the torch.load call (the line
setting raw = torch.load(str(path), map_location="cpu", weights_only=False))
explaining that weights_only=False is required to load Savanna checkpoints
containing non-tensor data but is unsafe for untrusted files because it uses
pickle and can execute arbitrary code; state that callers must ensure the path
is trusted (or sanitize/validate inputs) before loading.
- Around line 95-96: The bare except in the single-file download block (the
except Exception: that logs "Single-file download failed for {repo_id}, trying
multi-part shards...") is too broad; replace it by catching specific HF and
network-related exceptions (for example
huggingface_hub.exceptions.RepositoryNotFoundError,
huggingface_hub.utils.entry_not_found_error or RevisionNotFoundError, and
requests.exceptions.HTTPError/ConnectionError) and bind the exception (except
(RepositoryNotFoundError, RevisionNotFoundError, HTTPError, ConnectionError) as
e:) so you can log the actual error (include exc_info or str(e)) while allowing
other unexpected exceptions to propagate (re-raise or don’t catch them). Locate
the except block around the single-file download attempt and update the except
clause and logger call accordingly.

In `@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py`:
- Around line 714-805: The Eden tests are launching the installed module (-m
bionemo.evo2.run.infer) instead of the local infer.py; update
test_infer_eden_deterministic to use the same launch pattern as
test_infer_eden_runs: replace the "-m bionemo.evo2.run.infer" style invocation
with the script path from _infer_script_path() (use that function to build the
cmd entry) and ensure the env prepends PYTHONPATH the same way (reuse
PRETEST_ENV modification logic used in test_infer_eden_runs) so the local source
infer.py is executed; reference test_infer_eden_deterministic,
test_infer_eden_runs, _infer_script_path, and PRETEST_ENV to locate where to
apply the changes.

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`:
- Line 63: The list comprehension using torch.load(pf) to build preds should
pass an explicit weights_only argument to avoid the deprecation warning and
ensure correct behavior for tensor-only prediction files; update the
comprehension that references pred_files and preds to call torch.load(pf,
weights_only=True) (or weights_only=False if non-weight objects are expected) so
loading is explicit and future-proof.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 77786c27-73d0-4f42-b35a-c5b97f7a217b

📥 Commits

Reviewing files that changed from the base of the PR and between 470e10d and 7cd8ede.

📒 Files selected for processing (22)
  • bionemo-recipes/recipes/evo2_megatron/README.md
  • bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb
  • bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb
  • bionemo-recipes/recipes/evo2_megatron/pyproject.toml
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/eden_provider.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/README.md
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/nemo2_to_mbridge.py
  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/_eden_roundtrip_helper.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py
  • bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py

Comment on lines +157 to +162
```
evo2_export_mbridge_to_vortex \
--mbridge-ckpt-dir /path/to/mbridge/iter_0000001 \
--output-path /path/to/output/model_vortex.pt \
--model-size evo2_1b_base
```

⚠️ Potential issue | 🟡 Minor

Add a language tag to these new fenced command blocks.

markdownlint will keep warning on both fences until they use a language identifier. bash fits the surrounding shell examples.

<details>
<summary>💡 Minimal fix</summary>

````diff
-```
+```bash
 evo2_export_mbridge_to_vortex \
   --mbridge-ckpt-dir /path/to/mbridge/iter_0000001 \
   --output-path /path/to/output/model_vortex.pt \
   --model-size evo2_1b_base
 ```
@@
-```
+```bash
 Step 1: Savanna -> MBridge
 evo2_convert_savanna_to_mbridge \
   --savanna-ckpt-path arcinstitute/savanna_evo2_1b_base \
   --mbridge-ckpt-dir /tmp/mbridge_1b \
   --model-size evo2_1b_base \
````

</details>


Also applies to: 181-194

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 markdownlint-cli2 (0.21.0)</summary>

[warning] 157-157: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against the current code and only fix it if needed.

In @bionemo-recipes/recipes/evo2_megatron/README.md around lines 157 - 162, add
a language identifier (bash) to the fenced code blocks containing the shell
commands (e.g., the blocks that show evo2_export_mbridge_to_vortex and
evo2_convert_savanna_to_mbridge) so markdownlint stops warning; locate the
backtick fences around those command examples and change each opening fence from
``` to ```bash (including the second block around the
evo2_convert_savanna_to_mbridge example).


</details>


<!-- This is an auto-generated comment by CodeRabbit -->

Comment on lines +656 to +699
```python
@dataclass
class Hyena20bModelProvider(HyenaModelProvider):
    """Config matching the Evo2 20B 1M context model (arcinstitute/evo2_20b).

    Source: evo2/configs/evo2-20b-1m.yml from ARC's evo2 repo.
    Layer pattern derived from: hcs=[0,4,7,11,14,18,21], hcm=[1,5,8,12,15,19,22],
    hcl=[2,6,9,13,16,20,23], attn=[3,10,17].
    """

    hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH"
    num_layers: int = 24
    seq_length: int = 1_048_576
    hidden_size: int = 8192
    num_groups_hyena: int = 8192
    num_groups_hyena_medium: int = 512
    num_groups_hyena_short: int = 512
    make_vocab_size_divisible_by: int = 8
    tokenizer_library: str = "byte-level"
    mapping_type: str = "base"
    ffn_hidden_size: int = 22528
    gated_linear_unit: bool = True
    num_attention_heads: int = 64
    use_cpu_initialization: bool = False
    hidden_dropout: float = 0.0
    attention_dropout: float = 0.0
    params_dtype: torch.dtype = torch.bfloat16
    normalization: str = "RMSNorm"
    add_qkv_bias: bool = False
    add_bias_linear: bool = False
    layernorm_epsilon: float = 1e-6
    recompute_granularity: str = "full"
    recompute_method: str = "uniform"
    recompute_num_layers: int = 4
    hyena_init_method: str = "small_init"
    hyena_output_layer_init_method: str = "wang_init"
    hyena_filter_no_wd: bool = True
    rotary_base: int = 1_000_000
    seq_len_interpolation_factor: float = 128
    hyena_medium_conv_len: int = 128
    short_conv_len: int = 7
    hyena_short_conv_len: int = 3
    add_attn_proj_bias: bool = True
    hyena_out_proj_bias: bool = True
```

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check if these attributes are used in the HyenaModel or related code
rg -n "hyena_medium_conv_len|short_conv_len|hyena_short_conv_len|hyena_out_proj_bias" --type py
```

Repository: NVIDIA/bionemo-framework

Length of output: 4294


🏁 Script executed:

```bash
# Check the HyenaModelProvider parent class definition
rg -n "class HyenaModelProvider" -A 20 --type py
```

Repository: NVIDIA/bionemo-framework

Length of output: 2633


🏁 Script executed:

```bash
# Check how these attributes are used in the provide() method
sed -n '360,405p' bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
```

Repository: NVIDIA/bionemo-framework

Length of output: 2802


🏁 Script executed:

```bash
# Check HyenaConfig definition to see parent attributes
cat -n bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_config.py | head -150
```

Repository: NVIDIA/bionemo-framework

Length of output: 5210


🏁 Script executed:

```bash
# Check _HyenaConfigForFlops and how it's initialized in provide()
sed -n '340,370p' bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
```

Repository: NVIDIA/bionemo-framework

Length of output: 1868


🏁 Script executed:

```bash
# Search for hyena_out_proj_bias definition and usage
rg -n "hyena_out_proj_bias" -B 2 -A 2 --type py
```

Repository: NVIDIA/bionemo-framework

Length of output: 591


🏁 Script executed:

```bash
# Check if short_conv_len is referenced or if it should be something else
rg -n "short_conv_len" --type py
```

Repository: NVIDIA/bionemo-framework

Length of output: 2270


Critical naming error and unused attribute in Hyena20bModelProvider.

Line 695 defines short_conv_len: int = 7, which appears nowhere else in the codebase and conflicts with the consistent naming pattern used throughout (e.g., hyena_short_conv_len at line 696 and in HyenaConfig). This attribute is likely a naming error and should either be removed or renamed to match the standard attribute name.

Additionally, line 698 defines hyena_out_proj_bias: bool = True, which has no corresponding definition in HyenaConfig and is never referenced in any model code. This orphaned attribute should be removed unless it serves a documented purpose.

The attributes hyena_medium_conv_len and hyena_short_conv_len are valid overrides of HyenaConfig defaults and are used in the model's FLOPs calculation, but the above two attributes need clarification or removal.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`
around lines 656 - 699, Hyena20bModelProvider defines an incorrect/unused
attribute short_conv_len and an orphan hyena_out_proj_bias; remove
short_conv_len (it duplicates/typoed counterpart hyena_short_conv_len) and
delete hyena_out_proj_bias unless you add a corresponding field in HyenaConfig
and wire it into the model code; update Hyena20bModelProvider by removing the
short_conv_len and hyena_out_proj_bias declarations (or rename short_conv_len to
hyena_short_conv_len only if HyenaConfig lacks that field and you also add it
there).

Comment on lines +1115 to +1132
```python
MODEL_OPTIONS: dict[str, object] = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}


def infer_model_type(model_size: str) -> str:
    """Infer the model architecture type from the model size key.

    Returns:
        "hyena" if the key is in HYENA_MODEL_OPTIONS, "eden" if in EDEN_MODEL_OPTIONS.

    Raises:
        ValueError: If the key is not found in any model options dict.
    """
    if model_size in HYENA_MODEL_OPTIONS:
        return "hyena"
    elif model_size in EDEN_MODEL_OPTIONS:
        return "eden"
    else:
        raise ValueError(f"Unknown model size: {model_size!r}. Valid options: {sorted(MODEL_OPTIONS.keys())}")
```
⚠️ Potential issue | 🟡 Minor

Potential inconsistency if key collision occurs between HYENA and EDEN options.

MODEL_OPTIONS is created via {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}, so if a key exists in both, Eden's value wins. However, infer_model_type checks HYENA_MODEL_OPTIONS first, so it would return "hyena" for a colliding key while MODEL_OPTIONS[key] returns the Eden provider.

Consider adding a runtime assertion to detect collisions:

Suggested collision check

```diff
 MODEL_OPTIONS: dict[str, object] = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}
+
+# Ensure no key collisions between Hyena and Eden options
+_colliding_keys = set(HYENA_MODEL_OPTIONS.keys()) & set(EDEN_MODEL_OPTIONS.keys())
+if _colliding_keys:
+    raise ValueError(f"Key collision between HYENA and EDEN model options: {_colliding_keys}")
```

Also, the docstring for infer_model_type is missing the Args section per Google-style convention.
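As a self-contained sketch, the guard could also be factored into a small helper; the dicts used below are stand-ins for the real HYENA_MODEL_OPTIONS and EDEN_MODEL_OPTIONS:

```python
def merge_model_options(hyena_options: dict, eden_options: dict) -> dict:
    """Merge the two option dicts, refusing silent key collisions.

    A colliding key would make MODEL_OPTIONS[key] and infer_model_type(key)
    disagree, so fail loudly at import time instead.
    """
    colliding = set(hyena_options) & set(eden_options)
    if colliding:
        raise ValueError(f"Key collision between HYENA and EDEN model options: {sorted(colliding)}")
    return {**hyena_options, **eden_options}


# Stand-in option dicts for illustration only:
MODEL_OPTIONS = merge_model_options({"hyena_1b": "hyena"}, {"eden_7b": "eden"})
```

Because the helper raises before the merged dict exists, infer_model_type can never observe a key whose provider was silently overridden.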

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`
around lines 1115 - 1132, The bug: MODEL_OPTIONS is built as
{**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS} which lets EDEN override HYENA on
key collision, but infer_model_type checks HYENA first causing inconsistent
behavior; fix by adding a runtime collision check after constructing
MODEL_OPTIONS that computes collisions = set(HYENA_MODEL_OPTIONS) &
set(EDEN_MODEL_OPTIONS) and either raise a clear ValueError (or log and resolve
to a chosen precedence) if collisions is non-empty, and update infer_model_type
to rely on MODEL_OPTIONS (or document the chosen precedence) so behavior is
consistent; also update infer_model_type's docstring to include an Args section
describing the model_size parameter.

Comment on lines +334 to +336
```python
if not is_hyena:
    model_provider._pg_collection = ProcessGroupCollection.use_mpu_process_groups()
```

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Check how _pg_collection is used in the model provider hierarchy
rg -n "_pg_collection" --type py -C 3
```

Repository: NVIDIA/bionemo-framework

Length of output: 3544


🏁 Script executed:

```bash
# Find the model_provider class definition and type
rg -n "model_provider\s*=" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 3 -A 1 | head -50
```

Repository: NVIDIA/bionemo-framework

Length of output: 364


🏁 Script executed:

```bash
# Check the model_provider class definition
rg -n "class.*Provider" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -A 10 | grep -A 10 "_pg_collection"
```

Repository: NVIDIA/bionemo-framework

Length of output: 50


🏁 Script executed:

```bash
# Look for any setter method or property for _pg_collection
rg -n "def.*pg_collection|_pg_collection\s*=" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -B 2 -A 2
```

Repository: NVIDIA/bionemo-framework

Length of output: 50


🏁 Script executed:

```bash
# Find the model_provider class - check evo2_provider.py for class definition
cat bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py | head -100
```

Repository: NVIDIA/bionemo-framework

Length of output: 4585


🏁 Script executed:

```bash
# Check for _pg_collection initialization in __init__ or class attributes
rg -n "_pg_collection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py -B 5 -A 5
```

Repository: NVIDIA/bionemo-framework

Length of output: 1161


🏁 Script executed:

```bash
# Look for how ProcessGroupCollection is used - find imports and usage
rg -n "ProcessGroupCollection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 5 -A 5
```

Repository: NVIDIA/bionemo-framework

Length of output: 1140


🏁 Script executed:

```bash
# Check the context around finalize() call - what does it do?
rg -n "def finalize" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -A 15
```

Repository: NVIDIA/bionemo-framework

Length of output: 50


🏁 Script executed:

```bash
# Search for any public setter methods or properties related to pg_collection in the provider
rg -n "def.*pg_collection|@property" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py -B 2 -A 5
```

Repository: NVIDIA/bionemo-framework

Length of output: 492


🏁 Script executed:

```bash
# Check if _pg_collection is accessed anywhere else or how it's used after being set
rg -n "\._pg_collection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 3 -A 3
```

Repository: NVIDIA/bionemo-framework

Length of output: 292


🏁 Script executed:

```bash
# Search for any comments or TODOs around this pattern
rg -n "pg_collection|process.*group" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 5 -A 5 | head -80
```

Repository: NVIDIA/bionemo-framework

Length of output: 828


🏁 Script executed:

```bash
# Check if there are any comments in the test helper file that explain why this is needed
rg -n "pg_collection" bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/_eden_roundtrip_helper.py -B 10 -A 5
```

Repository: NVIDIA/bionemo-framework

Length of output: 663


🏁 Script executed:

```bash
# Look for use_mpu_process_groups method definition or documentation
rg -n "use_mpu_process_groups|ProcessGroupCollection" bionemo-recipes/recipes/evo2_megatron/src/ -B 2 -A 10 | head -100
```

Repository: NVIDIA/bionemo-framework

Length of output: 13059


Consider using a public API or setter method for _pg_collection configuration.

The direct assignment to model_provider._pg_collection works but sets a private attribute, which is fragile if the base class implementation changes. Unlike Hyena models that configure process groups internally through component constructors, non-Hyena models require this external setup. Consider whether a public setter method or initialization parameter would make this pattern more maintainable.
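One hedged sketch of what such a public entry point could look like, using a minimal stand-in class rather than the real provider (the actual GPTModelProvider API may differ):

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ProviderSketch:
    """Minimal stand-in for a model provider that owns a process-group collection."""

    _pg_collection: Optional[Any] = field(default=None, repr=False)

    def set_process_group_collection(self, pg_collection: Any) -> None:
        """Public setter so callers never reach into the private attribute."""
        self._pg_collection = pg_collection

    @property
    def pg_collection(self) -> Optional[Any]:
        """Read-only view of the configured collection."""
        return self._pg_collection


# Call-site sketch: replaces `model_provider._pg_collection = ...`;
# the string argument is a placeholder for ProcessGroupCollection.use_mpu_process_groups().
provider = ProviderSketch()
provider.set_process_group_collection("pg-collection-placeholder")
```

With a setter like this, the Hyena path stays untouched while non-Hyena callers get a supported hook instead of a private-attribute write.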

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py` around
lines 334 - 336, You're directly mutating the private attribute
model_provider._pg_collection with
ProcessGroupCollection.use_mpu_process_groups(), which is fragile; instead
expose and use a public API or setter on the provider (e.g., add or call a
method like set_process_group_collection or a constructor/init parameter on the
ModelProvider class) so non-Hyena models can be configured without touching
internals—update the provider implementation to accept and store the
ProcessGroupCollection via that public method and replace the direct assignment
at the call site with the new setter or init call, keeping Hyena models'
internal behavior unchanged.

Comment on lines +48 to +56
```python
latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
if latest_file.exists():
    iteration = latest_file.read_text().strip()
    iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
else:
    iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
    if not iter_dirs:
        raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
    iter_dir = iter_dirs[-1]
```
⚠️ Potential issue | 🟠 Major

Handle direct iter_* checkpoint paths here.

This resolver only works when --mbridge-ckpt-dir points at the checkpoint root. The new docs in this PR also show passing .../iter_0000001 directly, and that currently falls through to glob("iter_*") on the iteration directory itself and raises FileNotFoundError even though the checkpoint is valid.

🛠️ Possible fix

```diff
 def load_mbridge_state_dict(mbridge_ckpt_dir: Path) -> dict[str, torch.Tensor]:
@@
-    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
-    if latest_file.exists():
-        iteration = latest_file.read_text().strip()
-        iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
-    else:
-        iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
-        if not iter_dirs:
-            raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
-        iter_dir = iter_dirs[-1]
+    if mbridge_ckpt_dir.name.startswith("iter_"):
+        iter_dir = mbridge_ckpt_dir
+    else:
+        latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
+        if latest_file.exists():
+            iteration = latest_file.read_text().strip()
+            iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
+        else:
+            iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
+            if not iter_dirs:
+                raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
+            iter_dir = iter_dirs[-1]
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

```python
if mbridge_ckpt_dir.name.startswith("iter_"):
    iter_dir = mbridge_ckpt_dir
else:
    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
    if latest_file.exists():
        iteration = latest_file.read_text().strip()
        iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
    else:
        iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
        if not iter_dirs:
            raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
        iter_dir = iter_dirs[-1]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`
around lines 48 - 56, The current logic assumes mbridge_ckpt_dir is the
checkpoint root and fails if the user passes an iter_* directory; modify the
resolver around latest_file/iter_dir so that if mbridge_ckpt_dir.name matches
the iter_* pattern (e.g., startswith "iter_" or matches r"^iter_\d+$") you treat
mbridge_ckpt_dir itself as the iter_dir; otherwise keep the existing flow (check
for latest_checkpointed_iteration.txt, parse iteration into iter_{:07d}, or
fallback to glob("iter_*")). Update uses of iteration/iter_dirs to reflect this
early-path selection so valid direct iter_* paths are accepted.

Comment on lines +133 to +167
```python
embed_key = "embedding.word_embeddings.weight"
if embed_key in mbridge_state_dict:
    vortex_sd["embedding_layer.weight"] = mbridge_state_dict[embed_key]
    vortex_sd["unembed.weight"] = mbridge_state_dict[embed_key]

for layer_idx, symbol in enumerate(pattern):
    prefix = f"decoder.layers.{layer_idx}"
    block_prefix = f"blocks.{layer_idx}"

    if symbol != "*":
        _convert_hyena_layer(
            mbridge_state_dict,
            vortex_sd,
            prefix,
            block_prefix,
            symbol,
            te_enabled,
            num_groups,
            filter_order,
            medium_conv_len,
        )
    else:
        _convert_attention_layer(
            mbridge_state_dict,
            vortex_sd,
            prefix,
            block_prefix,
            te_enabled,
        )

    _convert_mlp(mbridge_state_dict, vortex_sd, prefix, block_prefix, te_enabled)

final_norm_key = "decoder.final_norm.weight"
if final_norm_key in mbridge_state_dict:
    vortex_sd["norm.scale"] = mbridge_state_dict[final_norm_key]
```
⚠️ Potential issue | 🟠 Major

Validate required tensor keys before writing the export.

Every mapping here is optional, so a bad --model-size or --no-te choice can silently drop required tensors and still produce a .pt file plus config.json. For a format converter, that makes corruption very hard to catch. Please collect missing mandatory keys per layer and raise before saving.

Also applies to: 172-296
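One possible shape for such a guard, run after the conversion loop and before saving; the per-layer key name used below is a hypothetical placeholder, since a real check would derive the expected keys from the same helpers that perform the conversion:

```python
def validate_vortex_state_dict(vortex_sd: dict, pattern: str) -> None:
    """Fail fast if the exported state dict is missing mandatory tensors.

    `blocks.{i}.mlp.w1` is an illustrative stand-in for whatever per-layer
    keys the converters are expected to produce.
    """
    missing = []
    # Top-level tensors every export must carry.
    for key in ("embedding_layer.weight", "unembed.weight", "norm.scale"):
        if key not in vortex_sd:
            missing.append(key)
    # One mandatory key per layer, annotated with the layer's pattern symbol.
    for layer_idx, symbol in enumerate(pattern):
        expected = f"blocks.{layer_idx}.mlp.w1"
        if expected not in vortex_sd:
            missing.append(f"layer {layer_idx} ({symbol}): {expected}")
    if missing:
        raise ValueError(f"Export is missing {len(missing)} mandatory tensor(s): {missing}")
```

Raising before torch.save turns a silently corrupt .pt file into an immediate, descriptive failure on a bad --model-size or --no-te choice.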

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`
around lines 133 - 167, The export currently silently omits required tensors
(e.g., embedding.word_embeddings.weight, decoder.final_norm.weight and per-layer
weights produced by _convert_hyena_layer, _convert_attention_layer, and
_convert_mlp) so add a validation pass after the loop that defines the mandatory
target keys (embedding_layer.weight, unembed.weight, norm.scale and all expected
per-layer keys derived from prefix/block_prefix and the pattern) and check their
presence in the resulting vortex_sd (or mbridge_state_dict if conversions expect
source keys); collect missing keys into a list and raise a descriptive exception
listing layer and key names (including references to the layer index and symbol)
before writing the .pt/config.json to fail fast on bad --model-size or --no-te
choices.

Comment on lines +39 to +79
```python
@pytest.fixture(scope="module")
def savanna_checkpoint_path(tmp_path_factory):
    """Download the 1b savanna checkpoint from HuggingFace."""
    cache_dir = tmp_path_factory.mktemp("savanna_ckpt")
    path = hf_hub_download(
        repo_id=SAVANNA_1B_REPO,
        filename="savanna_evo2_1b_base.pt",
        local_dir=str(cache_dir),
    )
    return path


@pytest.fixture(scope="module")
def vortex_reference_path(tmp_path_factory):
    """Download the 1b vortex checkpoint from HuggingFace."""
    cache_dir = tmp_path_factory.mktemp("vortex_ref")
    path = hf_hub_download(
        repo_id=VORTEX_1B_REPO,
        filename="evo2_1b_base.pt",
        local_dir=str(cache_dir),
    )
    return path


@pytest.fixture(scope="module")
def roundtrip_vortex_sd(savanna_checkpoint_path):
    """Perform savanna -> mbridge -> vortex conversion and return the vortex state dict."""
    provider_cls = HYENA_MODEL_OPTIONS[MODEL_SIZE]
    model_provider = provider_cls()
    pattern = model_provider.hybrid_override_pattern

    savanna_sd = load_savanna_state_dict(savanna_checkpoint_path)
    mbridge_sd = savanna_to_mbridge_state_dict(savanna_sd, pattern, te_enabled=True)
    vortex_sd = mbridge_to_vortex_state_dict(mbridge_sd, model_provider, te_enabled=True)
    return vortex_sd


@pytest.fixture(scope="module")
def vortex_reference_sd(vortex_reference_path):
    """Load the reference vortex state dict from HuggingFace."""
    return torch.load(vortex_reference_path, map_location="cpu", weights_only=False)
```
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
# First, locate the file and verify it exists
find . -type f -name "test_checkpoint_roundtrip.py" | head -20
```

Repository: NVIDIA/bionemo-framework

Length of output: 155


🏁 Script executed:

```bash
# Check if pyproject.toml exists and review relevant configs
find . -type f -name "pyproject.toml" | head -5
```

Repository: NVIDIA/bionemo-framework

Length of output: 294


🏁 Script executed:

```bash
# Read the test file to verify the code snippet
find . -type f -name "test_checkpoint_roundtrip.py" -exec cat -n {} \; 2>/dev/null | head -100
```

Repository: NVIDIA/bionemo-framework

Length of output: 4683


🏁 Script executed:

```bash
# Check if these tests are actually run in CI
grep -r "CI\|GITHUB_ACTIONS\|skipif" "./bionemo-recipes/recipes/evo2_megatron/tests/" --include="*.py" | head -20
```

Repository: NVIDIA/bionemo-framework

Length of output: 3359


🏁 Script executed:

```bash
# Look for other hf_hub_download usage patterns in the codebase
rg "hf_hub_download" -A 5 -B 2 --type py | head -80
```

Repository: NVIDIA/bionemo-framework

Length of output: 9229


🏁 Script executed:

```bash
# Check for torch.load patterns with weights_only in the codebase
rg "torch\.load.*weights_only" -B 2 -A 2 --type py | head -60
```

Repository: NVIDIA/bionemo-framework

Length of output: 6234


🏁 Script executed:

```bash
# Look at the test markers to understand test context
grep -n "pytest.mark\|@pytest" "./bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py"
```

Repository: NVIDIA/bionemo-framework

Length of output: 542


🏁 Script executed:

```bash
# Check the load_savanna_state_dict implementation
rg "def load_savanna_state_dict" -A 15 --type py
```

Repository: NVIDIA/bionemo-framework

Length of output: 2041


🏁 Script executed:

```bash
# Look at the full test to understand what comparisons are being done
sed -n '82,150p' "./bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py"
```

Repository: NVIDIA/bionemo-framework

Length of output: 2470


Pin HuggingFace checkpoints to immutable revisions and load reference checkpoint safely.

The fixtures download Savanna and Vortex checkpoints without specifying a revision=, so the golden data drifts if upstream repos change. Additionally, vortex_reference_sd() uses weights_only=False when loading a remote .pt file, unnecessarily allowing pickle code execution in the test. Add explicit commit SHAs to both hf_hub_download() calls and change the torch.load to use weights_only=True for the reference fixture (or use safetensors format if available).

Example fix

```diff
 path = hf_hub_download(
     repo_id=SAVANNA_1B_REPO,
     filename="savanna_evo2_1b_base.pt",
+    revision="<commit-sha>",
     local_dir=str(cache_dir),
 )

 path = hf_hub_download(
     repo_id=VORTEX_1B_REPO,
     filename="evo2_1b_base.pt",
+    revision="<commit-sha>",
     local_dir=str(cache_dir),
 )

-    return torch.load(vortex_reference_path, map_location="cpu", weights_only=False)
+    return torch.load(vortex_reference_path, map_location="cpu", weights_only=True)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py`
around lines 39 - 79, The fixtures download mutable HuggingFace checkpoints and
allow unsafe pickle loading; update savanna_checkpoint_path and
vortex_reference_path to pass explicit immutable commit SHAs via the revision=
parameter in their hf_hub_download(...) calls (use the specific commit SHA for
SAVANNA_1B_REPO and VORTEX_1B_REPO to pin the golden data), and modify
vortex_reference_sd to load the reference safely by calling torch.load(...,
map_location="cpu", weights_only=True) (or prefer safetensors if a .safetensors
artifact exists) so remote .pt files cannot execute pickle code.

Comment on lines +148 to +186
```python
@pytest.mark.slow
def test_roundtrip_prediction_equality(
    eden_ckpt: Path,
    hf_exported_dir: Path,
    hf_reimported_dir: Path,
    tmp_path,
):
    """Verify that predictions from the original and roundtripped models match.

    Runs predict on both the original mbridge checkpoint and on the re-imported HF checkpoint
    (loaded via AutoBridge) and compares per-token log probabilities.
    """
    num_sequences = 2
    seq_lengths = [64, 64]

    fasta_path = tmp_path / "test.fasta"
    create_fasta_file(fasta_path, num_sequences, sequence_lengths=seq_lengths, repeating_dna_pattern=ALU_SEQUENCE)

    env = copy.deepcopy(PRETEST_ENV)
    if is_a6000_gpu():
        env["NCCL_P2P_DISABLE"] = "1"

    # Predictions from the original mbridge checkpoint
    original_preds = _run_predict(eden_ckpt, fasta_path, tmp_path / "orig_preds", env)

    assert "log_probs_seqs" in original_preds
    assert "seq_idx" in original_preds

    # Load the original and reimported HF models and compare forward pass
    original_hf = LlamaForCausalLM.from_pretrained(hf_exported_dir, torch_dtype=torch.bfloat16).eval()
    reimported_hf = LlamaForCausalLM.from_pretrained(hf_reimported_dir, torch_dtype=torch.bfloat16).eval()

    # Quick sanity: HF forward pass should produce identical outputs for both
    input_ids = torch.randint(0, 256, (1, 32))
    with torch.no_grad():
        orig_logits = original_hf(input_ids).logits
        reimp_logits = reimported_hf(input_ids).logits

    torch.testing.assert_close(orig_logits, reimp_logits, atol=0, rtol=0)
```
⚠️ Potential issue | 🟡 Minor

Test computes original_preds but only compares HF models to each other.

The test runs _run_predict on eden_ckpt (lines 170-174) but only asserts that keys exist. The actual comparison (lines 183-186) is between original_hf and reimported_hf forward passes, not involving original_preds.

If the intent is to verify that the roundtripped model produces the same predictions as the original mbridge checkpoint, the comparison should include original_preds. Otherwise, the docstring claim "compares per-token log probabilities" is misleading since only HF-to-HF comparison occurs.
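A sketch of the missing numeric comparison, assuming both prediction dicts expose the "log_probs_seqs" and "seq_idx" entries the test already asserts for the original run (numpy stands in here for torch.testing.assert_close, and the second dict would come from a hypothetical predict run against the re-imported checkpoint):

```python
import numpy as np


def assert_predictions_match(original_preds: dict, roundtrip_preds: dict, atol: float = 1e-3) -> None:
    """Compare per-token log probabilities from two predict runs.

    Sequences are aligned by seq_idx before comparison, in case the two
    runs emit them in different orders.
    """
    orig_order = np.argsort(np.asarray(original_preds["seq_idx"]))
    rt_order = np.argsort(np.asarray(roundtrip_preds["seq_idx"]))
    orig = np.asarray(original_preds["log_probs_seqs"], dtype=np.float64)[orig_order]
    rt = np.asarray(roundtrip_preds["log_probs_seqs"], dtype=np.float64)[rt_order]
    np.testing.assert_allclose(orig, rt, atol=atol, rtol=0)
```

In the test itself, the second argument would be produced by another _run_predict call pointed at the roundtripped checkpoint, which is what would make the docstring's "compares per-token log probabilities" claim accurate.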

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`
around lines 148 - 186, The test currently computes original_preds via
_run_predict(eden_ckpt, ...) but then only compares original_hf and
reimported_hf logits; change the test to run _run_predict on the roundtripped HF
checkpoint (use hf_reimported_dir) to produce hf_preds and then compare
original_preds["log_probs_seqs"] to hf_preds["log_probs_seqs"] (or the
equivalent per-token log-prob key) using a numeric assert (e.g.,
torch.testing.assert_close or numpy.testing.assert_allclose) so the comparison
actually verifies the roundtrip predictions; locate and update the block that
currently creates original_hf/reimported_hf and the final
torch.testing.assert_close to instead call _run_predict for hf_reimported_dir
(or both HF and eden if you want) and perform the assertion on the
"log_probs_seqs" entries.

jstjohn added 8 commits March 11, 2026 00:13
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>